import torch
import torch.nn as nn
import numpy as np
from rlcore.distributions import Categorical
import torch.nn.functional as F
import math
from multiagent.scenarios.hunter_invader_para import ScenePara


def Norm(x):
    """Return the Euclidean (L2) length of vector *x*."""
    length = np.linalg.norm(x)
    return length

class INVADER_Control(nn.Module):
    """Scripted (non-learned) controller for invader agents.

    Exposes the same interface as a learned policy (``act`` / ``get_value``)
    but simply steers every invader toward the single landmark by choosing
    the discrete direction best aligned with the landmark's relative
    position read from the observation.
    """

    def __init__(self, action_space, num_agents, num_entities, input_size=16, hidden_dim=128, embed_dim=None,
                 pos_index=2, norm_in=False, nonlin=nn.ReLU, n_heads=1, mask_dist=None, entity_mp=False, device=None):
        # Only ``num_agents`` is actually used; the remaining parameters exist
        # solely to mirror the constructor signature of the learned policy.
        super().__init__()
        self.num_agents = num_agents

    def forward(self, inp, state, mask=None):
        # Scripted controller: there is no network forward pass.
        raise NotImplementedError

    def act(self, inp, state, mask=None, deterministic=False):
        """Return ``(value, action, action_log_probs, state)`` for a batch.

        ``inp`` has shape (num_threads * num_agents, obs_dim) in agent-major
        order: a1t1, a1t2, ..., a2t1, a2t2, ...  Each returned action is the
        integer code of the movement direction most aligned with the
        landmark's relative position, assumed to sit at obs[4:6]
        (depends on the scenario's observation layout — confirm there).
        Dummy zeros are returned for value / log-probs / state.
        """
        self.thread_num = inp.shape[0] // self.num_agents
        # Batch size must be an exact multiple of the number of agents.
        assert self.thread_num == inp.shape[0] / self.num_agents

        # Reshape to (threads, agents, obs_dim).
        obs = inp.view(self.num_agents, self.thread_num, -1).permute(1, 0, 2)

        assert ScenePara.num_landmarks == 1  # scripted chase supports exactly one landmark

        # -1 marks "not yet decided"; every entry must be overwritten below.
        # Allocated on inp's device so the outputs live where the input does
        # (the original always allocated on CPU, breaking GPU callers).
        action = torch.ones(size=(obs.shape[0], obs.shape[1], 1), device=inp.device) * -1

        for thread in range(self.thread_num):
            for agent in range(self.num_agents):
                # Landmark position relative to this agent.
                rel = obs[thread, agent, 4:6].cpu().numpy()
                unit = rel / (Norm(rel) + 0.001)  # epsilon avoids divide-by-zero
                # Alignment of the normalized direction with up/down/right/left.
                # (Dot products with the four axis unit vectors reduce exactly
                # to +/- the components.)
                up, dn = unit[1], -unit[1]
                ri, le = unit[0], -unit[0]
                direct = int(np.argmax([up, dn, ri, le]))
                # Map the winning axis index to the environment action code.
                # Identical to the original if/elif chain: 0->4, 1->3, 2->2, 3->1.
                # NOTE(review): with the documented encoding
                # NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4] this sends
                # up->LEFT(4) and left->UP(1); kept exactly as written —
                # confirm against the environment's actual action encoding.
                action[thread, agent, 0] = 4 - direct

        assert not (action == -1).any()  # every agent must have received an action
        # Back to agent-major flat layout: (threads * agents, 1).
        action = action.permute(1, 0, 2)
        action = action.reshape(self.thread_num * self.num_agents, 1)

        # Scripted policy: value, log-probs and recurrent state are dummies.
        value = torch.zeros_like(action)
        action_log_probs = torch.zeros_like(action)
        state = torch.zeros_like(action)
        return value, action, action_log_probs, state

    def get_value(self, inp, state, mask):
        # Dummy value estimate shaped like ``state``.
        return torch.zeros_like(state)